// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Asynchronous `Connector` trait and `HttpConnector` implementation.

use core::future::Future;
use std::error::Error;
use std::io;

use crate::util::ConnectorConfig;
use crate::{AsyncRead, AsyncWrite, TcpStream, Uri};

/// Asynchronous connection factory used by `async_impl::Client`.
///
/// A `Connector` turns a [`Uri`] into a future that resolves to a stream the
/// client can both read from and write to.
pub trait Connector {
    /// Stream type yielded once a connection has been established.
    type Stream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static;

    /// Error type reported when a connection attempt fails. It must be
    /// convertible into a boxed, thread-safe `Error`.
    type Error: Into<Box<dyn Error + Send + Sync>>;

    /// Future returned by [`Connector::connect`], resolving to either a
    /// [`Connector::Stream`] or a [`Connector::Error`].
    type Future: Future<Output = Result<Self::Stream, Self::Error>>
        + Send
        + Sync
        + Unpin
        + 'static;

    /// Returns a future that establishes a connection to `uri`.
    fn connect(&self, uri: &Uri) -> Self::Future;
}

/// Connector that opens HTTP or HTTPS connections asynchronously.
///
/// Implements the `async_impl::Connector` trait; its behaviour (proxies, TLS
/// settings) is driven by the wrapped `ConnectorConfig`.
#[derive(Default)]
pub struct HttpConnector {
    config: ConnectorConfig,
}

impl HttpConnector {
    /// Builds an `HttpConnector` that uses the given `ConnectorConfig`.
    pub(crate) fn new(config: ConnectorConfig) -> Self {
        Self { config }
    }
}

/// Opens a TCP connection to `addr` with `TCP_NODELAY` enabled.
///
/// `set_nodelay(true)` turns off Nagle's algorithm so small writes are sent
/// without being held back for batching.
///
/// # Errors
///
/// Returns any I/O error raised while connecting or while setting the
/// `TCP_NODELAY` option.
async fn tcp_stream(addr: &str) -> io::Result<TcpStream> {
    let stream = TcpStream::connect(addr).await?;
    stream.set_nodelay(true)?;
    Ok(stream)
}

#[cfg(not(feature = "__tls"))]
mod no_tls {
    use core::future::Future;
    use core::pin::Pin;
    use std::io::{Error, ErrorKind};

    use super::{tcp_stream, Connector, HttpConnector};
    use crate::{TcpStream, Uri};

    impl Connector for HttpConnector {
        type Stream = TcpStream;
        type Error = Error;
        type Future =
            Pin<Box<dyn Future<Output = Result<Self::Stream, Self::Error>> + Sync + Send>>;

        /// Opens a plain TCP connection for `uri`.
        ///
        /// When a configured proxy matches `uri`, the connection is made to
        /// the proxy's authority; otherwise it is made to `uri`'s authority.
        ///
        /// # Errors
        ///
        /// The returned future fails with `ErrorKind::InvalidInput` if the
        /// chosen URI has no authority, and otherwise with any error from
        /// establishing the TCP connection.
        fn connect(&self, uri: &Uri) -> Self::Future {
            // Pick the address lazily so the target's authority is only
            // looked up (and allocated) when no proxy applies.
            let addr = match self.config.proxies.match_proxy(uri) {
                Some(proxy) => proxy.via_proxy(uri).authority().map(|auth| auth.to_string()),
                None => uri.authority().map(|auth| auth.to_string()),
            }
            .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "uri has no authority"));

            Box::pin(async move {
                let addr = addr?;
                tcp_stream(&addr).await
            })
        }
    }
}

#[cfg(feature = "__tls")]
mod tls {
    use core::future::Future;
    use core::pin::Pin;
    use std::io::{Error, ErrorKind, Write};

    use super::{tcp_stream, Connector, HttpConnector};
    use crate::async_impl::ssl_stream::{AsyncSslStream, MixStream};
    use crate::error::CauseMessage;
    use crate::{AsyncReadExt, AsyncWriteExt, Scheme, TcpStream, Uri};

    /// Status-line prefixes meaning the proxy accepted the `CONNECT`.
    /// HTTP/1.0 is accepted too because some proxies answer `CONNECT` with an
    /// HTTP/1.0 status line.
    const TUNNEL_ESTABLISHED: [&[u8]; 2] = [b"HTTP/1.1 200", b"HTTP/1.0 200"];

    /// Status-line prefixes meaning the proxy requires authentication.
    const TUNNEL_AUTH_REQUIRED: [&[u8]; 2] = [b"HTTP/1.1 407", b"HTTP/1.0 407"];

    /// Maximum number of bytes accepted for the proxy's reply headers.
    const TUNNEL_BUF_LEN: usize = 8192;

    impl Connector for HttpConnector {
        type Stream = MixStream<TcpStream>;
        type Error = Error;
        type Future =
            Pin<Box<dyn Future<Output = Result<Self::Stream, Self::Error>> + Sync + Send>>;

        /// Connects to `uri`, going through a matching proxy if one is
        /// configured.
        ///
        /// * `http`: returns a plain TCP stream (to the proxy when one
        ///   matches, otherwise to the target).
        /// * `https`: returns a TLS stream. When a proxy matches, an HTTP
        ///   `CONNECT` tunnel to the target is opened on the proxy connection
        ///   before the TLS handshake.
        ///
        /// # Panics
        ///
        /// Panics if `uri` lacks a scheme, authority, host or a port that
        /// converts to `u16`, or if the matched proxy URI has no authority.
        /// Callers are expected to pass a fully populated URI.
        fn connect(&self, uri: &Uri) -> Self::Future {
            // Make sure all parts of uri are accurate.
            let mut addr = uri.authority().unwrap().to_string();
            let uri_host = uri.host().unwrap();
            let host = uri_host.as_str().to_string();
            let port = uri.port().unwrap().as_u16().unwrap();
            // Passed to `ssl_new`; presumably used for SNI / certificate
            // hostname checks — confirm against `ssl_new`.
            let host_name = uri_host.to_string();
            let mut auth = None;
            let mut is_proxy = false;

            if let Some(proxy) = self.config.proxies.match_proxy(uri) {
                // Dial the proxy instead of the target.
                addr = proxy.via_proxy(uri).authority().unwrap().to_string();
                auth = proxy
                    .intercept
                    .proxy_info()
                    .basic_auth
                    .as_ref()
                    .and_then(|v| v.to_str().ok());
                is_proxy = true;
            }

            match *uri.scheme().unwrap() {
                Scheme::HTTP => {
                    Box::pin(async move { Ok(MixStream::Http(tcp_stream(&addr).await?)) })
                }
                Scheme::HTTPS => {
                    let config = self.config.tls.clone();
                    Box::pin(async move {
                        let mut tcp = tcp_stream(&addr).await?;

                        // Through a proxy, TLS is negotiated with the target
                        // inside a CONNECT tunnel on the proxy connection.
                        if is_proxy {
                            tcp = tunnel(tcp, host, port, auth).await?;
                        }

                        let mut stream = config
                            .ssl_new(&host_name)
                            .and_then(|ssl| AsyncSslStream::new(ssl.into_inner(), tcp))
                            .map_err(|e| Error::new(ErrorKind::Other, e))?;

                        Pin::new(&mut stream)
                            .connect()
                            .await
                            .map_err(|e| Error::new(ErrorKind::Other, e))?;
                        Ok(MixStream::Https(stream))
                    })
                }
            }
        }
    }

    /// Sends an HTTP `CONNECT host:port` request over `conn` (a connection to
    /// a proxy) and waits for the proxy's reply.
    ///
    /// `auth`, when present, is sent verbatim as the
    /// `Proxy-Authorization: Basic` credential. On success the same
    /// connection is returned so the caller can continue on the tunnel.
    ///
    /// # Errors
    ///
    /// Fails if writing to or reading from `conn` fails, if the proxy closes
    /// the connection before replying completely, if it answers `407`, if it
    /// answers any other non-`200` status, or if its reply headers exceed
    /// `TUNNEL_BUF_LEN` bytes.
    async fn tunnel(
        mut conn: TcpStream,
        host: String,
        port: u16,
        auth: Option<String>,
    ) -> Result<TcpStream, Error> {
        let mut req = Vec::new();

        write!(
            &mut req,
            "CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n"
        )?;

        if let Some(value) = auth {
            write!(&mut req, "Proxy-Authorization: Basic {value}\r\n")?;
        }

        write!(&mut req, "\r\n")?;

        conn.write_all(&req).await?;

        let mut buf = [0; TUNNEL_BUF_LEN];
        let mut pos = 0;

        loop {
            let n = conn.read(&mut buf[pos..]).await?;

            if n == 0 {
                return Err(other_io_error("error receiving from proxy"));
            }

            pos += n;
            match classify_tunnel_reply(&buf[..pos]) {
                TunnelReply::Established => return Ok(conn),
                TunnelReply::Incomplete if pos == buf.len() => {
                    return Err(other_io_error("proxy headers too long for tunnel"));
                }
                TunnelReply::Incomplete => {}
                TunnelReply::AuthRequired => {
                    return Err(other_io_error("proxy authentication required"));
                }
                TunnelReply::Failed => return Err(other_io_error("unsuccessful tunnel")),
            }
        }
    }

    /// State of a proxy's reply to `CONNECT`, judged from the bytes received
    /// so far.
    enum TunnelReply {
        /// More bytes are needed before a decision can be made.
        Incomplete,
        /// A `200` status line followed by the blank line ending the headers.
        Established,
        /// A `407` status line.
        AuthRequired,
        /// Any other reply.
        Failed,
    }

    /// Classifies the bytes received so far (`resp`) in reply to `CONNECT`.
    fn classify_tunnel_reply(resp: &[u8]) -> TunnelReply {
        if starts_with_any(resp, &TUNNEL_ESTABLISHED) {
            if resp.ends_with(b"\r\n\r\n") {
                TunnelReply::Established
            } else {
                TunnelReply::Incomplete
            }
        } else if starts_with_any(resp, &TUNNEL_AUTH_REQUIRED) {
            TunnelReply::AuthRequired
        } else if TUNNEL_ESTABLISHED
            .iter()
            .chain(TUNNEL_AUTH_REQUIRED.iter())
            .any(|status| status.starts_with(resp))
        {
            // The status line was split across reads and what has arrived so
            // far can still become a recognised reply: keep reading.
            TunnelReply::Incomplete
        } else {
            TunnelReply::Failed
        }
    }

    /// Returns `true` if `resp` begins with any of `prefixes`.
    fn starts_with_any(resp: &[u8], prefixes: &[&[u8]]) -> bool {
        prefixes.iter().any(|prefix| resp.starts_with(prefix))
    }

    /// Wraps `msg` in an `ErrorKind::Other` I/O error.
    fn other_io_error(msg: &str) -> Error {
        Error::new(ErrorKind::Other, CauseMessage::new(msg))
    }
}
